{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "\n",
    "import os\n",
    "import sys\n",
    "import warnings\n",
    "\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "sys.path.append(\"../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.data.conll2003.prc import conll2003_preprocess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder holding the CoNLL-2003 files. Set the CONLL2003_DIR environment variable to\n",
    "# point somewhere else; joining with \"\" guarantees a trailing separator, which the\n",
    "# original literal had (the preprocessing helper's path handling is not visible here).\n",
    "data_dir = os.path.join(os.environ.get(\"CONLL2003_DIR\", \"/home/eartemov/ae/work/conll2003/\"), \"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f14d2a1ce44947ce98b9f430cf82caf1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Process /home/eartemov/ae/work/conll2003/eng.train', max=2195…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fa6fddbba84a4de78a48e7503af8d616",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Process /home/eartemov/ae/work/conll2003/eng.testa', max=5504…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e3c3e6e2ebdb4e51ba4051a1499e53cf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Process /home/eartemov/ae/work/conll2003/eng.testb', max=5035…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "conll2003_preprocess(data_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## IO markup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.data import bert_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The pre-trained model you are loading is a cased model but you have not set `do_lower_case` to False. We are setting `do_lower_case=False` for you but you may want to check this behavior.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Creating labels vocabs', max=6973, style=ProgressStyle(descri…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    }
   ],
   "source": [
    "# Build every input path from `data_dir` (defined in the preprocessing section) so the\n",
    "# dataset location is configured in exactly one place.\n",
    "data = bert_data.LearnData.create(\n",
    "    train_df_path=os.path.join(data_dir, \"eng.train.train.csv\"),\n",
    "    valid_df_path=os.path.join(data_dir, \"eng.testa.dev.csv\"),\n",
    "    idx2labels_path=os.path.join(data_dir, \"idx2labels2.txt\"),\n",
    "    clear_cache=True,\n",
    "    model_name=\"bert-base-cased\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.models.bert_models import BERTBiLSTMCRF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The first positional argument is the number of entries in the training label vocabulary.\n",
    "n_labels = len(data.train_ds.idx2label)\n",
    "model = BERTBiLSTMCRF.create(\n",
    "    n_labels,\n",
    "    model_name=\"bert-base-cased\",\n",
    "    lstm_dropout=0.,\n",
    "    crf_dropout=0.3,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.train.train import NerLearner"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path given to the learner as its third positional argument (presumably where the\n",
    "# checkpoint is stored; confirm against NerLearner).\n",
    "checkpoint_path = \"/home/eartemov/ae/work/models/conll2003-BERTBiLSTMCRF-base-IO.cpt\"\n",
    "# t_total is the epoch count times len(data.train_dl).\n",
    "total_steps = num_epochs * len(data.train_dl)\n",
    "learner = NerLearner(model, data, checkpoint_path, t_total=total_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2235023"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.get_n_trainable_params()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "learner.fit(epochs=num_epochs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Predict (validation set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.data.bert_data import get_data_loader_for_predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "dl = get_data_loader_for_predict(data, df_path=data.valid_ds.config[\"df_path\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Predicting', max=109, style=ProgressStyle(description_width='…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    }
   ],
   "source": [
    "preds = learner.predict(dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn_crfsuite.metrics import flat_classification_report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer\n",
    "from modules.analyze_utils.plot_metrics import get_bert_span_report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the predictions and the gold labels with the same loader so the two\n",
    "# token sequences can be compared directly afterwards.\n",
    "pred_tokens, pred_labels = bert_labels2tokens(dl, preds)\n",
    "gold_bert_labels = [sample.bert_labels for sample in dl.dataset]\n",
    "true_tokens, true_labels = bert_labels2tokens(dl, gold_bert_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert pred_tokens == true_tokens\n",
    "# Report only the labels from index 4 onward in the training label vocabulary\n",
    "# (presumably the first four are service labels; confirm against idx2label).\n",
    "report_labels = data.train_ds.idx2label[4:]\n",
    "tokens_report = flat_classification_report(\n",
    "    true_labels, pred_labels, labels=report_labels, digits=4\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "       I_ORG     0.9536    0.9431    0.9483      2005\n",
      "         I_O     0.9973    0.9969    0.9971     41651\n",
      "      I_MISC     0.9163    0.9155    0.9159      1255\n",
      "       I_PER     0.9842    0.9863    0.9853      2848\n",
      "       I_LOC     0.9674    0.9584    0.9629      1922\n",
      "\n",
      "   micro avg     0.9916    0.9905    0.9911     49681\n",
      "   macro avg     0.9638    0.9600    0.9619     49681\n",
      "weighted avg     0.9916    0.9905    0.9911     49681\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(tokens_report)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Test (eng.testb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.data.bert_data import get_data_loader_for_predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Locate the test split under `data_dir` instead of repeating the absolute folder path.\n",
    "dl = get_data_loader_for_predict(data, df_path=os.path.join(data_dir, \"eng.testb.dev.csv\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Predicting', max=98, style=ProgressStyle(description_width='i…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    }
   ],
   "source": [
    "preds = learner.predict(dl)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn_crfsuite.metrics import flat_classification_report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "from modules.analyze_utils.utils import bert_labels2tokens, voting_choicer\n",
    "from modules.analyze_utils.plot_metrics import get_bert_span_report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the predictions and the gold labels with the same loader so the two\n",
    "# token sequences can be compared directly afterwards.\n",
    "pred_tokens, pred_labels = bert_labels2tokens(dl, preds)\n",
    "gold_bert_labels = [sample.bert_labels for sample in dl.dataset]\n",
    "true_tokens, true_labels = bert_labels2tokens(dl, gold_bert_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert pred_tokens == true_tokens\n",
    "# Report only the labels from index 4 onward in the training label vocabulary\n",
    "# (presumably the first four are service labels; confirm against idx2label).\n",
    "report_labels = data.train_ds.idx2label[4:]\n",
    "tokens_report = flat_classification_report(\n",
    "    true_labels, pred_labels, labels=report_labels, digits=4\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "       I_ORG     0.8933    0.9081    0.9006      2360\n",
      "         I_O     0.9953    0.9921    0.9937     37492\n",
      "      I_MISC     0.8260    0.8088    0.8173       910\n",
      "       I_PER     0.9727    0.9810    0.9768      2683\n",
      "       I_LOC     0.9219    0.9300    0.9259      1815\n",
      "\n",
      "   micro avg     0.9822    0.9808    0.9815     45260\n",
      "   macro avg     0.9218    0.9240    0.9229     45260\n",
      "weighted avg     0.9823    0.9808    0.9815     45260\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(tokens_report)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
